/*
 * Copyright 2022 The Blaze Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.execution.blaze.arrowio.util;

import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DateDayVector;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.NullVector;
import org.apache.arrow.vector.SmallIntVector;
import org.apache.arrow.vector.TimeStampMicroTZVector;
import org.apache.arrow.vector.TimeStampMicroVector;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.UInt1Vector;
import org.apache.arrow.vector.UInt2Vector;
import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.UInt8Vector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.complex.FixedSizeListVector;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.holders.NullableVarCharHolder;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarArray;
import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.unsafe.types.UTF8String;

/**
 * A column vector backed by Apache Arrow.
 *
 * <p>Each instance wraps one Arrow {@link ValueVector} and delegates every typed getter to a
 * type-specific accessor chosen at construction time. Boolean, integer (signed and unsigned),
 * floating point, decimal, string, binary, date, timestamp, null, list, fixed-size list, map and
 * struct vectors are handled. Currently calendar interval type is not supported.
 */
public final class ArrowColumnVector extends ColumnVector {

    /** Type-specific reader for the wrapped Arrow vector. */
    private final ArrowVectorAccessor accessor;

    /** One wrapper per struct field; {@code null} for non-struct vectors and after close(). */
    private ArrowColumnVector[] childColumns;

    /**
     * Wraps the given Arrow vector.
     *
     * @param vector the Arrow vector to expose as a Spark column vector
     * @throws UnsupportedOperationException if the concrete vector class has no accessor
     */
    public ArrowColumnVector(ValueVector vector) {
        super(ArrowUtils.fromArrowField(vector.getField()));

        this.accessor = createAccessor(vector);
        if (vector instanceof StructVector) {
            this.childColumns = wrapChildren((StructVector) vector);
        }
    }

    /**
     * Picks the accessor matching the concrete class of {@code vector}.
     *
     * <p>The {@link MapVector} test must stay ahead of the {@link ListVector} test because a map
     * vector is also a list vector.
     */
    private static ArrowVectorAccessor createAccessor(ValueVector vector) {
        if (vector instanceof BitVector) {
            return new BooleanAccessor((BitVector) vector);
        }
        if (vector instanceof UInt1Vector) {
            return new UInt1Accessor((UInt1Vector) vector);
        }
        if (vector instanceof UInt2Vector) {
            return new UInt2Accessor((UInt2Vector) vector);
        }
        if (vector instanceof UInt4Vector) {
            return new UInt4Accessor((UInt4Vector) vector);
        }
        if (vector instanceof UInt8Vector) {
            return new UInt8Accessor((UInt8Vector) vector);
        }
        if (vector instanceof TinyIntVector) {
            return new ByteAccessor((TinyIntVector) vector);
        }
        if (vector instanceof SmallIntVector) {
            return new ShortAccessor((SmallIntVector) vector);
        }
        if (vector instanceof IntVector) {
            return new IntAccessor((IntVector) vector);
        }
        if (vector instanceof BigIntVector) {
            return new LongAccessor((BigIntVector) vector);
        }
        if (vector instanceof Float4Vector) {
            return new FloatAccessor((Float4Vector) vector);
        }
        if (vector instanceof Float8Vector) {
            return new DoubleAccessor((Float8Vector) vector);
        }
        if (vector instanceof DecimalVector) {
            return new DecimalAccessor((DecimalVector) vector);
        }
        if (vector instanceof VarCharVector) {
            return new StringAccessor((VarCharVector) vector);
        }
        if (vector instanceof VarBinaryVector) {
            return new BinaryAccessor((VarBinaryVector) vector);
        }
        if (vector instanceof DateDayVector) {
            return new DateAccessor((DateDayVector) vector);
        }
        if (vector instanceof TimeStampMicroVector) {
            return new TimestampAccessor((TimeStampMicroVector) vector);
        }
        if (vector instanceof TimeStampMicroTZVector) {
            return new TimestampTZAccessor((TimeStampMicroTZVector) vector);
        }
        if (vector instanceof MapVector) {
            return new MapAccessor((MapVector) vector);
        }
        if (vector instanceof ListVector) {
            return new ArrayAccessor((ListVector) vector);
        }
        if (vector instanceof FixedSizeListVector) {
            return new FixedSizeArrayAccessor((FixedSizeListVector) vector);
        }
        if (vector instanceof StructVector) {
            return new StructAccessor((StructVector) vector);
        }
        if (vector instanceof NullVector) {
            return new NullAccessor((NullVector) vector);
        }
        throw new UnsupportedOperationException("unsupported vector type: " + vector.getClass());
    }

    /** Wraps every field vector of {@code structVector}, in field-id order. */
    private static ArrowColumnVector[] wrapChildren(StructVector structVector) {
        ArrowColumnVector[] children = new ArrowColumnVector[structVector.size()];
        for (int i = 0; i < children.length; ++i) {
            children[i] = new ArrowColumnVector(structVector.getVectorById(i));
        }
        return children;
    }

    /** Returns whether the accessor reports at least one null value. */
    @Override
    public boolean hasNull() {
        return accessor.getNullCount() > 0;
    }

    /** Returns the number of null values reported by the accessor. */
    @Override
    public int numNulls() {
        return accessor.getNullCount();
    }

    /** Closes all struct children (if any) and then the wrapped Arrow vector. */
    @Override
    public void close() {
        if (childColumns != null) {
            for (int i = 0; i < childColumns.length; i++) {
                childColumns[i].close();
                childColumns[i] = null;
            }
            childColumns = null;
        }
        accessor.close();
    }

    @Override
    public boolean isNullAt(int rowId) {
        return accessor.isNullAt(rowId);
    }

    @Override
    public boolean getBoolean(int rowId) {
        return accessor.getBoolean(rowId);
    }

    @Override
    public byte getByte(int rowId) {
        return accessor.getByte(rowId);
    }

    @Override
    public short getShort(int rowId) {
        return accessor.getShort(rowId);
    }

    @Override
    public int getInt(int rowId) {
        return accessor.getInt(rowId);
    }

    @Override
    public long getLong(int rowId) {
        return accessor.getLong(rowId);
    }

    @Override
    public float getFloat(int rowId) {
        return accessor.getFloat(rowId);
    }

    @Override
    public double getDouble(int rowId) {
        return accessor.getDouble(rowId);
    }

    /** Returns the decimal at {@code rowId}, or {@code null} when that row is null. */
    @Override
    public Decimal getDecimal(int rowId, int precision, int scale) {
        if (isNullAt(rowId)) {
            return null;
        }
        return accessor.getDecimal(rowId, precision, scale);
    }

    /** Returns the string at {@code rowId}, or {@code null} when that row is null. */
    @Override
    public UTF8String getUTF8String(int rowId) {
        if (isNullAt(rowId)) {
            return null;
        }
        return accessor.getUTF8String(rowId);
    }

    /** Returns the bytes at {@code rowId}, or {@code null} when that row is null. */
    @Override
    public byte[] getBinary(int rowId) {
        if (isNullAt(rowId)) {
            return null;
        }
        return accessor.getBinary(rowId);
    }

    /** Returns the array at {@code rowId}, or {@code null} when that row is null. */
    @Override
    public ColumnarArray getArray(int rowId) {
        if (isNullAt(rowId)) {
            return null;
        }
        return accessor.getArray(rowId);
    }

    /** Returns the map at {@code rowId}, or {@code null} when that row is null. */
    @Override
    public ColumnarMap getMap(int rowId) {
        if (isNullAt(rowId)) {
            return null;
        }
        return accessor.getMap(rowId);
    }

    /**
     * Returns the column wrapping the struct field at {@code ordinal}.
     *
     * <p>Children exist only when the wrapped vector is a struct vector and this column has not
     * been closed; otherwise the child array is {@code null} and this call fails.
     */
    @Override
    public ArrowColumnVector getChild(int ordinal) {
        return childColumns[ordinal];
    }

    /**
     * Base accessor: owns the wrapped vector, answers null-ness questions and rejects every typed
     * getter. Subclasses override only the getters that make sense for their vector type.
     */
    private abstract static class ArrowVectorAccessor {

        private final ValueVector vector;

        ArrowVectorAccessor(ValueVector vector) {
            this.vector = vector;
        }

        /**
         * Returns whether the row is null. A vector that has values but an empty validity buffer
         * is treated as containing no nulls.
         */
        boolean isNullAt(int rowId) {
            return !isValidityBufferAbsent() && vector.isNull(rowId);
        }

        /**
         * Returns the number of nulls, using the same "no validity buffer means no nulls" rule as
         * {@link #isNullAt(int)} so that both answers always agree.
         */
        int getNullCount() {
            return isValidityBufferAbsent() ? 0 : vector.getNullCount();
        }

        /** True when the vector holds values but its validity buffer has zero capacity. */
        private boolean isValidityBufferAbsent() {
            return vector.getValueCount() > 0 && vector.getValidityBuffer().capacity() == 0;
        }

        final void close() {
            vector.close();
        }

        /** Builds the exception thrown by getters the concrete accessor does not support. */
        private UnsupportedOperationException unsupported() {
            return new UnsupportedOperationException(this.getClass().getName());
        }

        boolean getBoolean(int rowId) {
            throw unsupported();
        }

        byte getByte(int rowId) {
            throw unsupported();
        }

        short getShort(int rowId) {
            throw unsupported();
        }

        int getInt(int rowId) {
            throw unsupported();
        }

        long getLong(int rowId) {
            throw unsupported();
        }

        float getFloat(int rowId) {
            throw unsupported();
        }

        double getDouble(int rowId) {
            throw unsupported();
        }

        Decimal getDecimal(int rowId, int precision, int scale) {
            throw unsupported();
        }

        UTF8String getUTF8String(int rowId) {
            throw unsupported();
        }

        byte[] getBinary(int rowId) {
            throw unsupported();
        }

        ColumnarArray getArray(int rowId) {
            throw unsupported();
        }

        ColumnarMap getMap(int rowId) {
            throw unsupported();
        }
    }

    /** Accessor for {@link NullVector}: every row is null. */
    private static class NullAccessor extends ArrowVectorAccessor {

        private final NullVector accessor;

        NullAccessor(NullVector vector) {
            super(vector);
            this.accessor = vector;
        }

        @Override
        boolean isNullAt(int rowId) {
            return true;
        }

        @Override
        int getNullCount() {
            // Skip the validity-buffer probe of the base class and ask the vector directly.
            return accessor.getNullCount();
        }
    }

    /** Accessor for {@link BitVector}: a stored 1 reads as {@code true}. */
    private static class BooleanAccessor extends ArrowVectorAccessor {

        private final BitVector accessor;

        BooleanAccessor(BitVector vector) {
            super(vector);
            this.accessor = vector;
        }

        @Override
        final boolean getBoolean(int rowId) {
            return accessor.get(rowId) == 1;
        }
    }

    /** Accessor for {@link TinyIntVector}. */
    private static class ByteAccessor extends ArrowVectorAccessor {

        private final TinyIntVector accessor;

        ByteAccessor(TinyIntVector vector) {
            super(vector);
            this.accessor = vector;
        }

        @Override
        final byte getByte(int rowId) {
            return accessor.get(rowId);
        }
    }

    /**
     * Accessor for {@link UInt1Vector}. The value is returned as a signed {@code byte}, so values
     * above {@code Byte.MAX_VALUE} appear negative.
     */
    private static class UInt1Accessor extends ArrowVectorAccessor {

        private final UInt1Vector accessor;

        UInt1Accessor(UInt1Vector vector) {
            super(vector);
            this.accessor = vector;
        }

        @Override
        final byte getByte(int rowId) {
            return accessor.get(rowId);
        }
    }

    /**
     * Accessor for {@link UInt2Vector}. The value is narrowed to a signed {@code short}, so values
     * above {@code Short.MAX_VALUE} appear negative.
     */
    private static class UInt2Accessor extends ArrowVectorAccessor {

        private final UInt2Vector accessor;

        UInt2Accessor(UInt2Vector vector) {
            super(vector);
            this.accessor = vector;
        }

        @Override
        final short getShort(int rowId) {
            return (short) accessor.get(rowId);
        }
    }

    /** Accessor for {@link UInt4Vector}; the value is returned as a signed {@code int}. */
    private static class UInt4Accessor extends ArrowVectorAccessor {

        private final UInt4Vector accessor;

        UInt4Accessor(UInt4Vector vector) {
            super(vector);
            this.accessor = vector;
        }

        @Override
        final int getInt(int rowId) {
            return accessor.get(rowId);
        }
    }

    /** Accessor for {@link UInt8Vector}; the value is returned as a signed {@code long}. */
    private static class UInt8Accessor extends ArrowVectorAccessor {

        private final UInt8Vector accessor;

        UInt8Accessor(UInt8Vector vector) {
            super(vector);
            this.accessor = vector;
        }

        @Override
        final long getLong(int rowId) {
            return accessor.get(rowId);
        }
    }

    /** Accessor for {@link SmallIntVector}. */
    private static class ShortAccessor extends ArrowVectorAccessor {

        private final SmallIntVector accessor;

        ShortAccessor(SmallIntVector vector) {
            super(vector);
            this.accessor = vector;
        }

        @Override
        final short getShort(int rowId) {
            return accessor.get(rowId);
        }
    }

    /** Accessor for {@link IntVector}. */
    private static class IntAccessor extends ArrowVectorAccessor {

        private final IntVector accessor;

        IntAccessor(IntVector vector) {
            super(vector);
            this.accessor = vector;
        }

        @Override
        final int getInt(int rowId) {
            return accessor.get(rowId);
        }
    }

    /** Accessor for {@link BigIntVector}. */
    private static class LongAccessor extends ArrowVectorAccessor {

        private final BigIntVector accessor;

        LongAccessor(BigIntVector vector) {
            super(vector);
            this.accessor = vector;
        }

        @Override
        final long getLong(int rowId) {
            return accessor.get(rowId);
        }
    }

    /** Accessor for {@link Float4Vector}. */
    private static class FloatAccessor extends ArrowVectorAccessor {

        private final Float4Vector accessor;

        FloatAccessor(Float4Vector vector) {
            super(vector);
            this.accessor = vector;
        }

        @Override
        final float getFloat(int rowId) {
            return accessor.get(rowId);
        }
    }

    /** Accessor for {@link Float8Vector}. */
    private static class DoubleAccessor extends ArrowVectorAccessor {

        private final Float8Vector accessor;

        DoubleAccessor(Float8Vector vector) {
            super(vector);
            this.accessor = vector;
        }

        @Override
        final double getDouble(int rowId) {
            return accessor.get(rowId);
        }
    }

    /** Accessor for {@link DecimalVector}; converts the stored value to a Spark decimal. */
    private static class DecimalAccessor extends ArrowVectorAccessor {

        private final DecimalVector accessor;

        DecimalAccessor(DecimalVector vector) {
            super(vector);
            this.accessor = vector;
        }

        @Override
        final Decimal getDecimal(int rowId, int precision, int scale) {
            if (isNullAt(rowId)) {
                return null;
            }
            return Decimal.apply(accessor.getObject(rowId), precision, scale);
        }
    }

    /**
     * Accessor for {@link VarCharVector}. The returned string is built from the address of the
     * Arrow buffer rather than from a copied byte array.
     */
    private static class StringAccessor extends ArrowVectorAccessor {

        private final VarCharVector accessor;

        /** Reused on every call to avoid allocating a holder per row. */
        private final NullableVarCharHolder stringResult = new NullableVarCharHolder();

        StringAccessor(VarCharVector vector) {
            super(vector);
            this.accessor = vector;
        }

        @Override
        final UTF8String getUTF8String(int rowId) {
            accessor.get(rowId, stringResult);
            if (stringResult.isSet == 0) {
                return null;
            }
            long address = stringResult.buffer.memoryAddress() + stringResult.start;
            int numBytes = stringResult.end - stringResult.start;
            return UTF8String.fromAddress(null, address, numBytes);
        }
    }

    /** Accessor for {@link VarBinaryVector}. */
    private static class BinaryAccessor extends ArrowVectorAccessor {

        private final VarBinaryVector accessor;

        BinaryAccessor(VarBinaryVector vector) {
            super(vector);
            this.accessor = vector;
        }

        @Override
        final byte[] getBinary(int rowId) {
            return accessor.getObject(rowId);
        }
    }

    /** Accessor for {@link DateDayVector}; the stored value is exposed through getInt. */
    private static class DateAccessor extends ArrowVectorAccessor {

        private final DateDayVector accessor;

        DateAccessor(DateDayVector vector) {
            super(vector);
            this.accessor = vector;
        }

        @Override
        final int getInt(int rowId) {
            return accessor.get(rowId);
        }
    }

    /** Accessor for {@link TimeStampMicroVector}; the stored value is exposed through getLong. */
    private static class TimestampAccessor extends ArrowVectorAccessor {

        private final TimeStampMicroVector accessor;

        TimestampAccessor(TimeStampMicroVector vector) {
            super(vector);
            this.accessor = vector;
        }

        @Override
        final long getLong(int rowId) {
            return accessor.get(rowId);
        }
    }

    /** Accessor for {@link TimeStampMicroTZVector}; the stored value is exposed through getLong. */
    private static class TimestampTZAccessor extends ArrowVectorAccessor {

        private final TimeStampMicroTZVector accessor;

        TimestampTZAccessor(TimeStampMicroTZVector vector) {
            super(vector);
            this.accessor = vector;
        }

        @Override
        final long getLong(int rowId) {
            return accessor.get(rowId);
        }
    }

    /** Accessor for {@link ListVector}; each row is a slice of the shared element column. */
    private static class ArrayAccessor extends ArrowVectorAccessor {

        private final ListVector accessor;

        /** Wrapper around the list's element vector, shared by every returned array. */
        private final ArrowColumnVector arrayData;

        ArrayAccessor(ListVector vector) {
            super(vector);
            this.accessor = vector;
            this.arrayData = new ArrowColumnVector(vector.getDataVector());
        }

        @Override
        final ColumnarArray getArray(int rowId) {
            int start = accessor.getElementStartIndex(rowId);
            int end = accessor.getElementEndIndex(rowId);
            return new ColumnarArray(arrayData, start, end - start);
        }
    }

    /** Accessor for {@link FixedSizeListVector}; each row is a slice of the element column. */
    private static class FixedSizeArrayAccessor extends ArrowVectorAccessor {

        private final FixedSizeListVector accessor;

        /** Wrapper around the list's element vector, shared by every returned array. */
        private final ArrowColumnVector arrayData;

        FixedSizeArrayAccessor(FixedSizeListVector vector) {
            super(vector);
            this.accessor = vector;
            this.arrayData = new ArrowColumnVector(vector.getDataVector());
        }

        @Override
        final ColumnarArray getArray(int rowId) {
            int start = accessor.getElementStartIndex(rowId);
            int end = accessor.getElementEndIndex(rowId);
            return new ColumnarArray(arrayData, start, end - start);
        }
    }

    /**
     * Any call to a "get" method will throw UnsupportedOperationException.
     *
     * <p>Accessing struct values in an ArrowColumnVector doesn't use this accessor. Instead, it
     * uses the getStruct() method defined in the parent class. Any call to a "get" method in this
     * class is a bug in the code.
     */
    private static class StructAccessor extends ArrowVectorAccessor {

        StructAccessor(StructVector vector) {
            super(vector);
        }
    }

    /**
     * Accessor for {@link MapVector}; each row is a slice of the shared key and value columns
     * taken from the map's entries struct.
     */
    private static class MapAccessor extends ArrowVectorAccessor {
        private final MapVector accessor;
        private final ArrowColumnVector keys;
        private final ArrowColumnVector values;

        MapAccessor(MapVector vector) {
            super(vector);
            this.accessor = vector;
            StructVector entries = (StructVector) vector.getDataVector();
            this.keys = new ArrowColumnVector(entries.getChild(MapVector.KEY_NAME));
            this.values = new ArrowColumnVector(entries.getChild(MapVector.VALUE_NAME));
        }

        @Override
        final ColumnarMap getMap(int rowId) {
            // The offset buffer stores one fixed-width entry per row; read this row's start.
            int index = rowId * MapVector.OFFSET_WIDTH;
            int offset = accessor.getOffsetBuffer().getInt(index);
            int length = accessor.getInnerValueCountAt(rowId);
            return new ColumnarMap(keys, values, offset, length);
        }
    }
}
